from typing import List
from typing import Optional

from pydantic import Field

from processors.providers.api_processor_interface import ApiProcessorInterface
from processors.providers.api_processor_interface import ApiProcessorSchema
from processors.providers.replicate.utils import fetch_data_from_api
from common.utils.utils import get_key_or_raise


class Blip2Input(ApiProcessorSchema):
    """Input schema for the Replicate BLIP-2 processor.

    Supplies the image plus either a caption request or a question about it.
    """

    # Previously declared as `image: Field(...)`, which put the FieldInfo in
    # the annotation slot and left the field untyped. It is now a required
    # string. NOTE(review): presumably a URL or data URI for the image —
    # confirm against the datasource widget's output.
    image: str = Field(
        ..., description='Input image to query or caption',
        widget='datasource',
    )
    # Request a caption instead of answering `query`.
    generate_caption: Optional[bool] = Field(
        False, description='Generate caption for image',
    )
    # Free-form question; the empty default is meant for captioning.
    query: Optional[str] = Field(
        '', description='Question to ask about this image. Leave blank for captioning',
    )


class Blip2Configuration(ApiProcessorSchema):
    """Configuration for the Replicate BLIP-2 processor.

    ``version`` pins the Replicate model version and ``sync_mode`` chooses
    whether to run synchronously. The other fields are sampling options.
    The temperature range in its description is documentation only; it is
    not enforced here.
    """

    use_nucleus_sampling: Optional[bool] = Field(
        default=False,
        description='Use nucleus sampling',
    )
    temperature: Optional[float] = Field(
        default=0.7,
        description='Temperature for use with nucleus sampling (minimum: 0.5; maximum: 1)',
    )
    version: str = Field(
        default='4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608',
        description='Model version',
    )
    sync_mode: Optional[bool] = Field(
        default=False,
        description='Run in synchronous mode',
    )


class Blip2Output(ApiProcessorSchema):
    """Output schema for the Replicate BLIP-2 processor."""

    # Generated results, each a dict. Their structure is not defined here.
    generations: List[dict] = Field(
        default=[], description='The completions generated by the model.',
    )
    # Raw JSON returned by the Replicate API.
    # NOTE(review): by default pydantic v1 does not treat names that start
    # with an underscore as model fields. Confirm that ApiProcessorSchema's
    # configuration makes this behave as intended.
    _api_response: dict = Field(
        default={}, description='The raw response from the API.',
    )


class Blip2(ApiProcessorInterface[Blip2Input, Blip2Output, Blip2Configuration]):
    """Processor that creates a BLIP-2 prediction through the Replicate API."""

    @staticmethod
    def name() -> str:
        """Return the processor identifier ``'replicate/blip2'``."""
        return 'replicate/blip2'

    def process(self) -> dict:
        """Submit a BLIP-2 prediction request to Replicate.

        The request body holds the pinned model ``version`` and an ``input``
        object. That object merges the configuration's model options with
        the processor input, and input values win on key collisions.

        Returns:
            A dict whose ``'async'`` key holds the ``urls`` object from the
            Replicate response. Its ``'_response'`` key wraps the full JSON
            response under ``'_api_response'``.

        Raises:
            Exception: if the API call does not return an OK response.
                ``get_key_or_raise`` is also expected to raise when
                ``replicate_api_key`` is missing from ``_env``.
        """
        _env = self._env
        replicate_api_key = get_key_or_raise(
            _env, 'replicate_api_key', 'No replicate_api_key found in _env',
        )

        url = 'https://api.replicate.com/v1/predictions'
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {replicate_api_key}',
        }

        # Model options: the configuration minus the fields that control how
        # the call is made. `version` goes top-level in the request, and
        # `sync_mode` is not currently honoured here.
        # The original code chained `.pop()` calls, which return the popped
        # value rather than the dict, so it crashed.
        model_params = self._config.dict()
        model_params.pop('sync_mode', None)
        model_params.pop('version', None)

        # NOTE(review): assumes the interface exposes the validated input as
        # `self._input`. The original referenced the builtin `input` by
        # mistake. `_env` is stripped so it is never sent to the API.
        input_params = self._input.dict()
        input_params.pop('_env', None)

        api_params = {
            'version': self._config.version,
            'input': {**model_params, **input_params},
        }

        api_response = fetch_data_from_api(url, api_params, headers)

        if not api_response.ok:
            raise Exception(f'Error calling Replicate API: {api_response.text}')

        json_api_response = api_response.json()
        return {
            'async': json_api_response['urls'],
            '_response': {
                '_api_response': json_api_response,
            },
        }
